%% Original author: Hari ThiruMoorthy
%  Final edits: MZ
%  Copyright 2021. All Rights Reserved
function [llr, dec_msg] = msk_dem_logmap(dat_rx, TREL, llr_in, rz_fg)
%MSK_DEM_LOGMAP  Log-MAP (BCJR) soft demodulation over an oversampled trellis.
%   [llr, dec_msg] = msk_dem_logmap(dat_rx, TREL, llr_in, rz_fg) runs the
%   forward (Alpha) / backward (Beta) log-domain recursion over the trellis
%   described by TREL and returns a posteriori LLRs for each input bit.
%
%   Inputs
%     dat_rx : received samples; length must be a multiple of
%              TREL.OverSampling (row or column vector, treated as a row).
%     TREL   : trellis struct. Fields read here: numInputSymbols (must be 2),
%              numStates, OverSampling, outputs (indexed as
%              outputs(input+1, 1:sps, state+1)), prevStates, prevStateIn,
%              nextStates (column idx corresponds to input bit idx-1).
%     llr_in : a priori LLRs ln(P(b=0)/P(b=1)), one per bit (at least
%              length(dat_rx)/sps values), or a scalar applied to every bit.
%     rz_fg  : true if the trellis is terminated in state 0 at the end;
%              otherwise all final states are treated as equally likely.
%
%   Outputs
%     llr     : 1 x len_m a posteriori LLRs ln(P(b=0|rx)/P(b=1|rx)).
%     dec_msg : 1 x len_m hard decisions (1 where llr < 0, else 0).
%
%   NOTE(review): the branch metric is real(rx_seg * out') + log P(bit) with
%   no noise-variance scaling; presumably dat_rx is already scaled by the
%   caller -- confirm.

sps = TREL.OverSampling;
ni  = TREL.numInputSymbols;
ns  = TREL.numStates;

% Force a row vector so that rxvec * out' below is an inner product.
dat_rx = dat_rx(:).';
len_s = length(dat_rx);
len_m = len_s / sps;           % number of trellis steps (= number of bits)

if len_m ~= floor(len_m)
    error('ERROR: Make length(RX) = multiple of sps!')
end
if log2(ni) ~= 1
    error('ERROR: This Function only works for EncoderK=1!')
end

LOG_ZERO = -1e9;               % log-domain stand-in for probability 0

%% A priori bit log-probabilities from L = ln(P0/P1):
%%   log P1 = -ln(1 + e^L),   log P0 = L + log P1.
%% ln(1 + e^L) is evaluated as max(L,0) + log1p(e^-|L|) so that a large
%% positive L does not overflow exp() and turn log P0 into -Inf.
if isscalar(llr_in)
    llr_in = repmat(llr_in, 1, len_m);
end
llr_in = llr_in(:).';
if numel(llr_in) < len_m
    error('ERROR: llr_in must have at least length(RX)/sps elements!')
end
log_one_plus_exp = max(llr_in, 0) + log1p(exp(-abs(llr_in)));
llr_a = [llr_in - log_one_plus_exp; -log_one_plus_exp];   % row 1: bit 0, row 2: bit 1

%% Branch metrics gamma(state, input, t): correlation of the t-th received
%% segment with the expected output samples of that branch, plus the a
%% priori log-probability of the input bit.
BranMetric = NaN(ns, ni, len_m);
for time = 1 : len_m
    rxvec = dat_rx((time-1)*sps + 1 : time*sps);   % samples [1..sps], [sps+1..2*sps], ...
    for ii = 1 : ns
        for jj = 1 : ni
            out = TREL.outputs(jj, 1:sps, ii);       % expected samples for (state ii-1, input jj-1)
            BranMetric(ii, jj, time) = real(rxvec * out') + llr_a(jj, time);
        end
    end
end

%% Forward recursion:
%% Alpha(t, m) = log sum_{m'} exp( Alpha(t-1, m') + gamma(m', m, t) ).
%% The encoder is assumed to start in state 0 at t = 0.
Alpha = zeros(len_m + 1, ns);
Alpha(1, 1) = 0;
Alpha(1, 2:end) = LOG_ZERO;

for time = 1 : len_m
    for dest_st = 0 : ns-1
        sum1 = LOG_ZERO;
        for idx = 1 : length(TREL.prevStates(dest_st+1, :))
            orig_st = TREL.prevStates(dest_st+1, idx);
            bit_in  = TREL.prevStateIn(dest_st+1, idx);
            sum1 = ln_of_sum_of_exps(sum1, Alpha(time, orig_st+1) + BranMetric(orig_st+1, bit_in+1, time));
        end
        % States with no reachable predecessor stay near LOG_ZERO.
        Alpha(time+1, dest_st+1) = sum1;
    end
end

%% Backward recursion:
%% Beta(t-1, m) = log sum_{m'} exp( Beta(t, m') + gamma(m, m', t) ).
%% If rz_fg is set, the trellis is terminated in state 0 at t = len_m;
%% otherwise every final state starts with log-metric 0.
Beta = zeros(len_m + 1, ns);
if (rz_fg)
    Beta(len_m + 1, 1) = 0;
    Beta(len_m + 1, 2:end) = LOG_ZERO;
end

for time = (len_m - 1) : -1 : 0
    for orig_st = 0 : ns-1
        sum2 = LOG_ZERO;
        for idx = 1 : length(TREL.nextStates(orig_st+1, :))
            dest_st = TREL.nextStates(orig_st+1, idx);
            bit_in  = idx - 1;                       % column idx of nextStates <-> input bit idx-1
            sum2 = ln_of_sum_of_exps(sum2, Beta(time+2, dest_st+1) + BranMetric(orig_st+1, bit_in+1, time+1));
        end
        Beta(time+1, orig_st+1) = sum2;
    end
end

%% A posteriori bit LLRs: L(u_t | rx) = ln( sum over bit-0 branches / sum over bit-1 branches ).
%% Only valid for EncoderK = 1 (two branches leave every state).
llr = zeros(1, len_m);
for time = 1 : len_m
    total = [LOG_ZERO, LOG_ZERO];   % unnormalized log P(bit=0 | rx), log P(bit=1 | rx)
    for bit = 0 : 1
        for orig_st = 0 : ns-1
            dest_st = TREL.nextStates(orig_st+1, bit+1);
            incr_prob = Alpha(time, orig_st+1) + Beta(time+1, dest_st+1) + BranMetric(orig_st+1, bit+1, time);
            total(bit+1) = ln_of_sum_of_exps(total(bit+1), incr_prob);
        end
    end
    llr(time) = total(1) - total(2);
end

% Hard decision: negative LLR -> bit 1. Ties (llr == 0) map to 0, so the
% output is always binary (the old (1 - sign(llr))/2 produced 0.5 on ties).
dec_msg = double(llr < 0);
%%=========================================================================
 

%%	ln_of_sum_of_exps  Jacobian logarithm: a = ln(e^x + e^y).
%  Evaluated as max(x, y) + ln(1 + e^-|x - y|), which avoids overflowing
%  exp() for large arguments. Inputs are log-domain values (scalars or
%  equally sized arrays); the result is element-wise.
%  When both inputs are the same infinity, x - y is NaN, so the result is
%  set directly to that infinity (ln(0 + 0) = -Inf, ln(Inf + Inf) = Inf).
function a = ln_of_sum_of_exps(x, y)
	m = max(x, y);
	a = m + log1p(exp(-abs(x - y)));
	inf_mask = isinf(m);
	a(inf_mask) = m(inf_mask);
%   Max-Log-MAP approximation alternative: a = max(x, y);
return;

